package fetcher

import (
	"time"
	"fmt"
	"cve-db/util"
	log "github.com/sirupsen/logrus"
	"compress/gzip"
	"cve-db/commonerr"
	"net/http"
	"net/url"
	"crypto/tls"
	"github.com/htcat/htcat"
	"bytes"
	"io/ioutil"
	"cve-db/config"
	"cve-db/models"
	"strings"
	"github.com/knqyf263/go-cpe/naming"
	"github.com/knqyf263/go-cpe/common"
)

// FetchRequest describes one feed to download with FetchFeedDatas.
type FetchRequest struct {
	// Year is copied unchanged into the matching FetchResult
	// (presumably the feed's year — confirm with callers).
	Year int
	// URL is the location the feed is downloaded from.
	URL  string
	// GZIP, when true, makes the downloaded bytes be gunzipped before they
	// are returned as FetchResult.Body.
	GZIP bool
}

// FetchResult holds the outcome of one successfully downloaded FetchRequest.
type FetchResult struct {
	// Year is copied from the originating FetchRequest.
	Year int
	// URL is copied from the originating FetchRequest.
	URL  string
	// Body is the downloaded content, already decompressed when the request
	// had GZIP set.
	Body []byte
}

// FetchFeedDatas downloads every feed in reqs concurrently and returns the
// successfully fetched bodies.
//
// One task per request is submitted to util.GenWorkers. A total htcat
// parallelism budget of 20 is split evenly across requests, with a floor of 1
// per request. The whole batch is bounded by a 10-minute timeout. On timeout,
// the results collected so far are returned together with an error. If any
// individual download fails, all collected results are returned together with
// an error listing the failures.
//
// An empty reqs slice returns (nil, nil) without starting any work.
func FetchFeedDatas(reqs []FetchRequest) (results []FetchResult, err error) {
	if len(reqs) == 0 {
		return nil, nil
	}

	// Both channels are buffered to len(reqs), and every task sends exactly
	// one value, so sends never block. They are intentionally NOT closed: if
	// we return on timeout, tasks still in flight would otherwise panic with
	// "send on closed channel". Unreferenced channels are garbage collected.
	resChan := make(chan FetchResult, len(reqs))
	errChan := make(chan error, len(reqs))

	for _, r := range reqs {
		log.Infof("Fetching... %s", r.URL)
	}

	// Split the connection budget across requests. Integer division yields 0
	// once there are more than 20 requests, so keep at least one connection.
	parallelism := 20 / len(reqs)
	if parallelism < 1 {
		parallelism = 1
	}

	tasks := util.GenWorkers(len(reqs))
	for _, r := range reqs {
		req := r // per-iteration copy for the closure (pre-Go 1.22 loop semantics)
		tasks <- func() {
			body, err := downloadFeedData(req, parallelism)
			if err != nil {
				errChan <- err
				return
			}
			resChan <- FetchResult{
				Year: req.Year,
				URL:  req.URL,
				Body: body,
			}
		}
	}

	errs := []error{}
	timeout := time.After(10 * time.Minute)
	for range reqs {
		select {
		case res := <-resChan:
			results = append(results, res)
			log.Infof("Fetched... %s", res.URL)
		case err := <-errChan:
			errs = append(errs, err)
		case <-timeout:
			return results, fmt.Errorf("Timeout Fetching")
		}
	}
	if len(errs) > 0 {
		return results, fmt.Errorf("%s", errs)
	}
	return results, nil
}

// buildHtppClient returns the HTTP client used to download feeds.
//
// When config.Conf.HTTPProxy is non-empty, it is parsed and every request is
// routed through that proxy. Otherwise the transport's Proxy field is left
// nil. A proxy URL that fails to parse is logged and reported as
// commonerr.ErrCouldNotParse.
//
// NOTE(review): TLS certificate verification is disabled
// (InsecureSkipVerify). This exposes downloads to man-in-the-middle tampering.
// It is kept to preserve existing behavior; consider enabling verification.
func buildHtppClient() (*http.Client, error) {
	transport := &http.Transport{
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
	}

	if rawProxy := config.Conf.HTTPProxy; rawProxy != "" {
		proxyURL, parseErr := url.Parse(rawProxy)
		if parseErr != nil {
			log.WithError(parseErr).WithFields(log.Fields{
				"PROXYURL": rawProxy}).Error("proxy url parse fail.")
			return nil, commonerr.ErrCouldNotParse
		}
		transport.Proxy = http.ProxyURL(proxyURL)
	}

	return &http.Client{Transport: transport}, nil
}

// downloadFeedData downloads a single feed with htcat, using parallelism
// concurrent connections, and returns its body. When req.GZIP is set, the
// downloaded bytes are gunzipped before being returned.
//
// A parallelism below 1 is clamped to 1, so htcat always gets a positive
// connection count.
//
// Errors:
//   - whatever buildHtppClient returns, if the HTTP client cannot be built;
//   - commonerr.ErrCouldNotParse if req.URL cannot be parsed or the gzip
//     stream cannot be read to completion;
//   - commonerr.ErrCouldNotDownload if the download fails or the body is not
//     a valid gzip stream.
func downloadFeedData(req FetchRequest, parallelism int) (body []byte, err error) {
	httpClient, err := buildHtppClient()
	if err != nil {
		log.WithError(err).WithFields(log.Fields{
			"DataFeedName": req.URL}).Error("http build error.")
		return nil, err
	}

	u, err := url.Parse(req.URL)
	if err != nil {
		return nil, commonerr.ErrCouldNotParse
	}

	if parallelism < 1 {
		parallelism = 1
	}

	var buf bytes.Buffer
	htc := htcat.New(httpClient, u, parallelism)
	if _, err := htc.WriteTo(&buf); err != nil {
		log.WithError(err).WithFields(log.Fields{
			"DataFeedName": req.URL}).Error("htcat get error.")
		return nil, commonerr.ErrCouldNotDownload
	}

	if !req.GZIP {
		return buf.Bytes(), nil
	}

	reader, err := gzip.NewReader(&buf)
	if err != nil {
		log.WithError(err).WithFields(log.Fields{
			"URL": req.URL}).Error("Failed to decompress NVD feedfile")
		return nil, commonerr.ErrCouldNotDownload
	}
	// Defer Close only after NewReader succeeded. On error, reader is nil,
	// and calling Close on it would panic.
	defer reader.Close()

	decompressed, err := ioutil.ReadAll(reader)
	if err != nil {
		log.WithError(err).WithFields(log.Fields{
			"URL": req.URL}).Error("Failed to Read NVD feedfile")
		return nil, commonerr.ErrCouldNotParse
	}
	return decompressed, nil
}

// ParseCpeURI parses a CPE name into a models.CpeBase.
//
// A uri starting with "cpe:/" is treated as a CPE 2.2 URI. Every '/' after
// that prefix is escaped as `\/`, and the result is unbound with
// naming.UnbindURI. Any other input is unbound with naming.UnbindFS.
// Unbinding errors are returned unchanged.
//
// The result holds the URI, formatted-string and WFN renderings of the parsed
// name, plus each WFN attribute rendered with the "%s" verb.
func ParseCpeURI(uri string) (*models.CpeBase, error) {
	var (
		wfn common.WellFormedName
		err error
	)

	switch {
	case strings.HasPrefix(uri, "cpe:/"):
		// Escaping is a no-op when the body has no '/', so it can be applied
		// unconditionally.
		escaped := "cpe:/" + strings.Replace(strings.TrimPrefix(uri, "cpe:/"), "/", `\/`, -1)
		wfn, err = naming.UnbindURI(escaped)
	default:
		wfn, err = naming.UnbindFS(uri)
	}
	if err != nil {
		return nil, err
	}

	// attr renders one WFN attribute value as a string.
	attr := func(name string) string {
		return fmt.Sprintf("%s", wfn.Get(name))
	}

	base := &models.CpeBase{
		URI:             naming.BindToURI(wfn),
		FormattedString: naming.BindToFS(wfn),
		WellFormedName:  wfn.String(),
	}
	base.CpeWFN = models.CpeWFN{
		Part:            attr(common.AttributePart),
		Vendor:          attr(common.AttributeVendor),
		Product:         attr(common.AttributeProduct),
		Version:         attr(common.AttributeVersion),
		Update:          attr(common.AttributeUpdate),
		Edition:         attr(common.AttributeEdition),
		Language:        attr(common.AttributeLanguage),
		SoftwareEdition: attr(common.AttributeSwEdition),
		TargetSW:        attr(common.AttributeTargetSw),
		TargetHW:        attr(common.AttributeTargetHw),
		Other:           attr(common.AttributeOther),
	}
	return base, nil
}
